package com.trade.rws.group;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.util.LinkedList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.trade.rws.dbselect.DBSelector.AbstractDataSourceTryer;
import com.trade.rws.dbselect.DBSelector.DataSourceTryer;
import com.trade.rws.rule.SqlType;
import com.trade.rws.util.GroupHintParser;
import com.trade.rws.util.SQLParser;



/**
 * Group-level {@link Statement} that runs each SQL statement on one of the
 * physical data sources behind a {@link TGroupDataSource}.
 *
 * <p>If the owning {@link TGroupConnection} already hands back a physical
 * connection for the request, the SQL runs on it directly. Otherwise the SQL is
 * dispatched through the group's DB selector, passing {@code retryingTimes}.
 * Every read and write is first checked against the group's flow control and
 * concurrent read/write limits.
 *
 * <p>At most one underlying statement and one result set are held at a time;
 * starting a new execution closes the previous ones.
 */
public class TGroupStatement implements Statement {
	private static final Log log = LogFactory.getLog(TGroupStatement.class);

	protected TGroupConnection tGroupConnection;
	protected TGroupDataSource tGroupDataSource;
	/** Retry count handed to the DB selector; taken from the data source at construction time. */
	protected int retryingTimes;

	public TGroupStatement(TGroupDataSource tGroupDataSource, TGroupConnection tGroupConnection) {
		this.tGroupDataSource = tGroupDataSource;
		this.tGroupConnection = tGroupConnection;
		this.retryingTimes = tGroupDataSource.getRetryingTimes();
	}

	/* ========================================================================
	 * The underlying (not necessarily physical) Statement; accessors are package-private.
	 * ======================================================================*/
	private Statement baseStatement;

	/**
	 * Sets the Statement that actually executes SQL on the underlying connection.
	 * A previously held statement is closed first; a failure while closing it is
	 * logged and otherwise ignored.
	 *
	 * @param baseStatement the new underlying statement
	 */
	void setBaseStatement(Statement baseStatement) {
		if (this.baseStatement != null) {
			try {
				this.baseStatement.close();
			} catch (SQLException e) {
				log.error("close baseStatement failed.", e);
			}
		}
		this.baseStatement = baseStatement;
	}

	/**
	 * Query timeout in seconds; applied to every underlying statement created by
	 * this object.
	 */
	protected int queryTimeout = 0;

	/** Fetch size applied to every underlying statement created by this object. */
	protected int fetchSize;

	/** Max rows applied to every underlying statement created by this object. */
	protected int maxRows;

	/**
	 * Result set of the most recent query, exposed through {@link #getResultSet()}.
	 * A statement holds at most one result set at a time.
	 */
	protected ResultSet currentResultSet;

	/**
	 * Update count of the most recent execution (only the last one is kept).
	 * It is -1 when the most recent execution was a query.
	 */
	protected int updateCount;

	protected int resultSetType = ResultSet.TYPE_FORWARD_ONLY;
	protected int resultSetConcurrency = ResultSet.CONCUR_READ_ONLY;

	// The JDBC spec does not specify a default holdability. -1 means "never set";
	// in that case the connection's holdability is used when a statement is created.
	// TODO: should this default to ResultSet.CLOSE_CURSORS_AT_COMMIT instead of -1?
	protected int resultSetHoldability = -1;

	public boolean execute(String sql) throws SQLException {
		return executeInternal(sql, -1, null, null);
	}

	public boolean execute(String sql, int autoGeneratedKeys) throws SQLException {
		return executeInternal(sql, autoGeneratedKeys, null, null);
	}

	public boolean execute(String sql, int[] columnIndexes) throws SQLException {
		return executeInternal(sql, -1, columnIndexes, null);
	}

	public boolean execute(String sql, String[] columnNames) throws SQLException {
		return executeInternal(sql, -1, null, columnNames);
	}

	/**
	 * Dispatches {@code execute(...)} to {@link #executeQuery} or {@code executeUpdate}
	 * based on the parsed SQL type.
	 *
	 * @return {@code true} if the SQL was run as a query (a result set is available),
	 *         {@code false} if it was run as an update (per the JDBC contract)
	 * @throws SQLException if the SQL type is not one of the supported kinds
	 */
	private boolean executeInternal(String sql, int autoGeneratedKeys, int[] columnIndexes, String[] columnNames)
			throws SQLException {
		SqlType sqlType = SQLParser.getSqlType(sql);
		if (sqlType == SqlType.SELECT || sqlType == SqlType.SELECT_FOR_UPDATE || sqlType == SqlType.SHOW) {
			executeQuery(sql);
			return true;
		}
		if (sqlType == SqlType.INSERT || sqlType == SqlType.UPDATE || sqlType == SqlType.DELETE
				|| sqlType == SqlType.REPLACE || sqlType == SqlType.TRUNCATE || sqlType == SqlType.CREATE
				|| sqlType == SqlType.DROP || sqlType == SqlType.LOAD || sqlType == SqlType.MERGE) {
			// At most one of the optional key-retrieval arguments is set by the public overloads.
			if (autoGeneratedKeys != -1) {
				executeUpdate(sql, autoGeneratedKeys);
			} else if (columnIndexes != null) {
				executeUpdate(sql, columnIndexes);
			} else if (columnNames != null) {
				executeUpdate(sql, columnNames);
			} else {
				executeUpdate(sql);
			}
			return false;
		}
		throw new SQLException("only select, insert, update, delete,replace,truncate,create,drop,load,merge sql is supported");
	}

	/* ========================================================================
	 * executeUpdate
	 * ======================================================================*/
	public int executeUpdate(String sql) throws SQLException {
		return executeUpdateInternal(sql, -1, null, null);
	}

	public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException {
		return executeUpdateInternal(sql, autoGeneratedKeys, null, null);
	}

	public int executeUpdate(String sql, int[] columnIndexes) throws SQLException {
		return executeUpdateInternal(sql, -1, columnIndexes, null);
	}

	public int executeUpdate(String sql, String[] columnNames) throws SQLException {
		return executeUpdateInternal(sql, -1, null, columnNames);
	}

	/**
	 * Runs an update on the connection's current physical connection if there is one;
	 * otherwise runs it through the write DB selector.
	 *
	 * @return the update count, also stored in {@link #updateCount}
	 */
	private int executeUpdateInternal(String sql, int autoGeneratedKeys, int[] columnIndexes, String[] columnNames)
			throws SQLException {
		checkClosed();
		ensureResultSetIsEmpty();
		recordWriteTimes();
		// increaseConcurrentWrite() increments the counter before checking the limit, so it
		// must run inside the try: the finally then balances it even when the limit is
		// exceeded or obtaining a connection fails.
		try {
			increaseConcurrentWrite();
			Connection conn = tGroupConnection.getBaseConnection(sql, false);
			if (conn != null) {
				sql = GroupHintParser.removeTddlGroupHint(sql);
				this.updateCount = executeUpdateOnConnection(conn, sql, autoGeneratedKeys, columnIndexes, columnNames);
			} else {
				// A non-negative index from the group hint wins; otherwise fall back to the thread-local index.
				Integer dataSourceIndex = GroupHintParser.convertHint2Index(sql);
				sql = GroupHintParser.removeTddlGroupHint(sql);
				if (dataSourceIndex < 0) {
					dataSourceIndex = ThreadLocalDataSourceIndex.getIndex();
				}
				this.updateCount = this.tGroupDataSource.getDBSelector(false).tryExecute(executeUpdateTryer,
						retryingTimes, sql, autoGeneratedKeys, columnIndexes, columnNames, dataSourceIndex);
			}
			return this.updateCount;
		} finally {
			decreaseConcurrentWrite();
		}
	}

	/**
	 * Creates a fresh underlying statement on {@code conn} (closing the previous one)
	 * and runs the update with whichever key-retrieval option was supplied.
	 */
	private int executeUpdateOnConnection(Connection conn, String sql, int autoGeneratedKeys, int[] columnIndexes,
			String[] columnNames) throws SQLException {
		Statement stmt = createStatementInternal(conn, false);
		if (autoGeneratedKeys != -1) {
			return stmt.executeUpdate(sql, autoGeneratedKeys);
		} else if (columnIndexes != null) {
			return stmt.executeUpdate(sql, columnIndexes);
		} else if (columnNames != null) {
			return stmt.executeUpdate(sql, columnNames);
		}
		return stmt.executeUpdate(sql);
	}

	/**
	 * Selector callback for updates. args = {sql, autoGeneratedKeys, columnIndexes, columnNames, ...}.
	 */
	private DataSourceTryer<Integer> executeUpdateTryer = new AbstractDataSourceTryer<Integer>() {
		public Integer tryOnDataSource(DataSourceWrapper dsw, Object... args) throws SQLException {
			Connection conn = TGroupStatement.this.tGroupConnection.createNewConnection(dsw, false);
			return executeUpdateOnConnection(conn, (String) args[0], (Integer) args[1], (int[]) args[2],
					(String[]) args[3]);
		}
	};

	/**
	 * Creates an underlying statement on {@code conn} and applies this statement's
	 * timeout, fetch size and max rows. Calls setBaseStatement, which closes any
	 * previously held statement.
	 *
	 * @param isBatch if true a plain statement is created; otherwise the configured
	 *        result set type/concurrency/holdability are used
	 */
	private Statement createStatementInternal(Connection conn, boolean isBatch) throws SQLException {
		Statement stmt;
		if (isBatch) {
			stmt = conn.createStatement();
		} else {
			int holdability = this.resultSetHoldability;
			if (holdability == -1) { // setResultSetHoldability was never called
				holdability = conn.getHoldability();
			}
			stmt = conn.createStatement(this.resultSetType, this.resultSetConcurrency, holdability);
		}

		// Registered before configuring it: if a setter below throws, the statement is
		// still tracked and will be closed by a later setBaseStatement/close.
		setBaseStatement(stmt);
		stmt.setQueryTimeout(queryTimeout);
		stmt.setFetchSize(fetchSize);
		stmt.setMaxRows(maxRows);
		return stmt;
	}

	/* ========================================================================
	 * executeBatch
	 * ======================================================================*/
	protected List<String> batchedArgs;

	public void addBatch(String sql) throws SQLException {
		checkClosed();
		if (batchedArgs == null) {
			batchedArgs = new LinkedList<String>();
		}
		if (sql != null) {
			batchedArgs.add(sql);
		}
	}

	public void clearBatch() throws SQLException {
		checkClosed();
		if (batchedArgs != null) {
			batchedArgs.clear();
		}
	}

	/**
	 * Executes the pending batch. The pending batch is cleared afterwards, whether
	 * or not execution succeeded.
	 *
	 * @return the per-statement update counts, or an empty array if nothing was batched
	 */
	public int[] executeBatch() throws SQLException {
		try {
			checkClosed();
			ensureResultSetIsEmpty();
			recordWriteTimes();
			// Only decrement the concurrent-write counter if it was actually incremented.
			try {
				increaseConcurrentWrite();
				if (batchedArgs == null || batchedArgs.isEmpty()) {
					return new int[0];
				}

				Connection conn = tGroupConnection.getBaseConnection(null, false);
				if (conn != null) {
					// With an existing connection, no retry is attempted. For writes, whether or not a
					// transaction is open, users expect everything after getConnection to run on the
					// same database and the same connection.
					return executeBatchOnConnection(conn, this.batchedArgs);
				}
				return tGroupDataSource.getDBSelector(false).tryExecute(null, executeBatchTryer, retryingTimes);
			} finally {
				decreaseConcurrentWrite();
			}
		} finally {
			if (batchedArgs != null) {
				batchedArgs.clear();
			}
		}
	}

	private DataSourceTryer<int[]> executeBatchTryer = new AbstractDataSourceTryer<int[]>() {
		public int[] tryOnDataSource(DataSourceWrapper dsw, Object... args) throws SQLException {
			Connection conn = TGroupStatement.this.tGroupConnection.createNewConnection(dsw, false);
			return executeBatchOnConnection(conn, TGroupStatement.this.batchedArgs);
		}
	};

	/** Adds every batched SQL to a fresh plain statement on {@code conn} and executes it. */
	private int[] executeBatchOnConnection(Connection conn, List<String> batchedSqls) throws SQLException {
		Statement stmt = createStatementInternal(conn, true);
		for (String sql : batchedSqls) {
			stmt.addBatch(sql);
		}
		return stmt.executeBatch();
	}

	/* ========================================================================
	 * Closing
	 * ======================================================================*/
	protected boolean closed; // whether this statement has been closed

	public void close() throws SQLException {
		close(true);
	}

	/**
	 * Closes the current result set and the underlying statement. Calling it again
	 * has no effect.
	 *
	 * @param removeThis whether to unregister this statement from the owning connection
	 * @throws SQLException if closing the underlying statement fails (state is cleared anyway)
	 */
	void close(boolean removeThis) throws SQLException {
		if (closed) {
			return;
		}
		closed = true;

		try {
			if (currentResultSet != null) {
				currentResultSet.close();
			}
		} catch (SQLException e) {
			log.warn("Close currentResultSet failed.", e);
		} finally {
			currentResultSet = null;
		}

		try {
			if (this.baseStatement != null) {
				this.baseStatement.close();
			}
		} finally {
			this.baseStatement = null;
			if (removeThis) {
				tGroupConnection.removeOpenedStatements(this);
			}
		}
	}

	protected void checkClosed() throws SQLException {
		if (closed) {
			throw new SQLException("No operations allowed after statement closed.");
		}
	}

	/**
	 * Closes the result set of the previous query before a new execution, as the
	 * JDBC spec requires. Close failures are logged; the reference is cleared regardless.
	 */
	protected void ensureResultSetIsEmpty() throws SQLException {
		if (currentResultSet != null) {
			try {
				currentResultSet.close();
			} catch (SQLException e) {
				log.error("exception on close last result set . can do nothing..", e);
			} finally {
				currentResultSet = null;
			}
		}
	}

	/* ========================================================================
	 * executeQuery
	 * ======================================================================*/
	/**
	 * Runs a query on the connection's current physical connection if there is one;
	 * otherwise runs it through the DB selector.
	 *
	 * <p>Only plain SELECTs issued in auto-commit mode are routed with {@code gotoRead=true}.
	 * SELECT ... FOR UPDATE, SHOW, and queries inside a transaction use the same selector
	 * as writes.
	 */
	public ResultSet executeQuery(String sql) throws SQLException {
		checkClosed();
		ensureResultSetIsEmpty();
		recordReadTimes();
		// JDBC: getUpdateCount() must return -1 when the current result is a ResultSet.
		this.updateCount = -1;
		// increaseConcurrentRead() increments before checking the limit, so it sits inside
		// the try to keep the counter balanced on every exit path.
		try {
			increaseConcurrentRead();
			boolean gotoRead = SqlType.SELECT.equals(SQLParser.getSqlType(sql)) && tGroupConnection.getAutoCommit();
			Connection conn = tGroupConnection.getBaseConnection(sql, gotoRead);
			if (conn != null) {
				sql = GroupHintParser.removeTddlGroupHint(sql);
				return executeQueryOnConnection(conn, sql);
			}
			// A non-negative index from the group hint wins; otherwise fall back to the thread-local index.
			Integer dataSourceIndex = GroupHintParser.convertHint2Index(sql);
			sql = GroupHintParser.removeTddlGroupHint(sql);
			if (dataSourceIndex < 0) {
				dataSourceIndex = ThreadLocalDataSourceIndex.getIndex();
			}
			return this.tGroupDataSource.getDBSelector(gotoRead).tryExecute(executeQueryTryer, retryingTimes, sql,
					dataSourceIndex);
		} finally {
			decreaseConcurrentRead();
		}
	}

	/** Creates a fresh underlying statement on {@code conn}, runs the query and keeps its result set. */
	protected ResultSet executeQueryOnConnection(Connection conn, String sql) throws SQLException {
		Statement stmt = createStatementInternal(conn, false);
		this.currentResultSet = stmt.executeQuery(sql);
		return this.currentResultSet;
	}

	/** Selector callback for queries. args = {sql, ...}. */
	protected DataSourceTryer<ResultSet> executeQueryTryer = new AbstractDataSourceTryer<ResultSet>() {
		public ResultSet tryOnDataSource(DataSourceWrapper dsw, Object... args) throws SQLException {
			String sql = (String) args[0];
			Connection conn = TGroupStatement.this.tGroupConnection.createNewConnection(dsw, true);
			return executeQueryOnConnection(conn, sql);
		}
	};

	public SQLWarning getWarnings() throws SQLException {
		checkClosed();
		if (baseStatement != null) {
			return baseStatement.getWarnings();
		}
		return null;
	}

	public void clearWarnings() throws SQLException {
		checkClosed();
		if (baseStatement != null) {
			baseStatement.clearWarnings();
		}
	}

	/* ========================================================================
	 * Methods with simple/local support
	 * ======================================================================*/
	/**
	 * Multiple result sets apparently only occur with stored procedures, so they are
	 * not supported; this flag is never set here.
	 */
	protected boolean moreResults;

	public boolean getMoreResults() throws SQLException {
		return moreResults;
	}

	public int getQueryTimeout() throws SQLException {
		return queryTimeout;
	}

	public void setQueryTimeout(int queryTimeout) throws SQLException {
		this.queryTimeout = queryTimeout;
	}

	public ResultSet getResultSet() throws SQLException {
		return currentResultSet;
	}

	public int getUpdateCount() throws SQLException {
		return updateCount;
	}

	public int getResultSetConcurrency() throws SQLException {
		return resultSetConcurrency;
	}

	/** Returns the configured holdability, or -1 if it was never set. */
	public int getResultSetHoldability() throws SQLException {
		return resultSetHoldability;
	}

	public int getResultSetType() throws SQLException {
		return resultSetType;
	}

	public void setResultSetType(int resultSetType) {
		this.resultSetType = resultSetType;
	}

	public void setResultSetConcurrency(int resultSetConcurrency) {
		this.resultSetConcurrency = resultSetConcurrency;
	}

	public void setResultSetHoldability(int resultSetHoldability) {
		this.resultSetHoldability = resultSetHoldability;
	}

	public Connection getConnection() throws SQLException {
		return tGroupConnection;
	}

	/* ========================================================================
	 * Fetch/limit settings (stored locally, applied to each new underlying
	 * statement) and unsupported methods
	 * ======================================================================*/
	public int getFetchDirection() throws SQLException {
		throw new UnsupportedOperationException("getFetchDirection");
	}

	public int getFetchSize() throws SQLException {
		return this.fetchSize;
	}

	public int getMaxFieldSize() throws SQLException {
		throw new UnsupportedOperationException("getMaxFieldSize");
	}

	public int getMaxRows() throws SQLException {
		return this.maxRows;
	}

	public void setCursorName(String cursorName) throws SQLException {
		throw new UnsupportedOperationException("setCursorName");
	}

	public void setEscapeProcessing(boolean escapeProcessing) throws SQLException {
		throw new UnsupportedOperationException("setEscapeProcessing");
	}

	public boolean getMoreResults(int current) throws SQLException {
		throw new UnsupportedOperationException("getMoreResults");
	}

	public void setFetchDirection(int fetchDirection) throws SQLException {
		throw new UnsupportedOperationException("setFetchDirection");
	}

	public void setFetchSize(int fetchSize) throws SQLException {
		this.fetchSize = fetchSize;
	}

	public void setMaxFieldSize(int maxFieldSize) throws SQLException {
		throw new UnsupportedOperationException("setMaxFieldSize");
	}

	public void setMaxRows(int maxRows) throws SQLException {
		this.maxRows = maxRows;
	}

	/**
	 * Delegates to the most recently created underlying statement.
	 *
	 * @throws SQLException if no statement has been executed yet
	 */
	public ResultSet getGeneratedKeys() throws SQLException {
		if (this.baseStatement != null) {
			return this.baseStatement.getGeneratedKeys();
		}
		throw new SQLException("在调用getGeneratedKeys前未执行过任何更新操作");
	}

	public void cancel() throws SQLException {
		throw new UnsupportedOperationException("cancel");
	}

	/** Returns true if this object is an instance of {@code iface} (java.sql.Wrapper contract). */
	public boolean isWrapperFor(Class<?> iface) throws SQLException {
		return iface != null && iface.isInstance(this);
	}

	/**
	 * Returns this object viewed as {@code iface}.
	 *
	 * @throws SQLException if this object does not implement {@code iface}
	 */
	public <T> T unwrap(Class<T> iface) throws SQLException {
		if (iface != null && iface.isInstance(this)) {
			return iface.cast(this);
		}
		throw new SQLException("TGroupStatement cannot be unwrapped as " + iface);
	}

	/** Returns whether {@link #close()} has been called on this statement. */
	public boolean isClosed() throws SQLException {
		return closed;
	}

	public void setPoolable(boolean poolable) throws SQLException {
		throw new SQLException("not support exception");
	}

	public boolean isPoolable() throws SQLException {
		throw new SQLException("not support exception");
	}

	// JDBC 4.1 (Java 7) methods. @Override is left off, presumably so the class still
	// compiles against pre-Java-7 JDBC interfaces.
	public void closeOnCompletion() throws SQLException {
		throw new SQLException("not support exception");
	}

	public boolean isCloseOnCompletion() throws SQLException {
		throw new SQLException("not support exception");
	}

	/** Rejects the read if the data source's read flow control does not allow it. */
	protected void recordReadTimes() throws SQLException {
		if (!tGroupDataSource.readFlowControl.allow()) {
			throw new SQLException(tGroupDataSource.readFlowControl.reportExceed());
		}
	}

	/** Rejects the write if the data source's write flow control does not allow it. */
	protected void recordWriteTimes() throws SQLException {
		if (!tGroupDataSource.writeFlowControl.allow()) {
			throw new SQLException(tGroupDataSource.writeFlowControl.reportExceed());
		}
	}

	/**
	 * Increments the concurrent-read counter and checks it against the limit (0 means
	 * unlimited). The counter stays incremented even when this throws, so callers must
	 * always pair it with {@link #decreaseConcurrentRead()} in a finally block.
	 */
	protected void increaseConcurrentRead() throws SQLException {
		int maxConcurrentReadRestrict = tGroupDataSource.connectionProperties.maxConcurrentReadRestrict;
		int concurrentReadCount = tGroupDataSource.concurrentReadCount.incrementAndGet();
		if (maxConcurrentReadRestrict != 0 && concurrentReadCount > maxConcurrentReadRestrict) {
			tGroupDataSource.readTimesReject.incrementAndGet();
			throw new SQLException("maxConcurrentReadRestrict reached , " + maxConcurrentReadRestrict);
		}
	}

	/**
	 * Increments the concurrent-write counter and checks it against the limit (0 means
	 * unlimited). The counter stays incremented even when this throws, so callers must
	 * always pair it with {@link #decreaseConcurrentWrite()} in a finally block.
	 */
	protected void increaseConcurrentWrite() throws SQLException {
		int maxConcurrentWriteRestrict = tGroupDataSource.connectionProperties.maxConcurrentWriteRestrict;
		int concurrentWriteCount = tGroupDataSource.concurrentWriteCount.incrementAndGet();
		if (maxConcurrentWriteRestrict != 0 && concurrentWriteCount > maxConcurrentWriteRestrict) {
			tGroupDataSource.writeTimesReject.incrementAndGet();
			throw new SQLException("maxConcurrentWriteRestrict reached , " + maxConcurrentWriteRestrict);
		}
	}

	/** Decrements the concurrent-read counter. */
	protected void decreaseConcurrentRead() throws SQLException {
		tGroupDataSource.concurrentReadCount.decrementAndGet();
	}

	/** Decrements the concurrent-write counter. */
	protected void decreaseConcurrentWrite() throws SQLException {
		tGroupDataSource.concurrentWriteCount.decrementAndGet();
	}

}
